import random
import torch
from sklearn.model_selection import train_test_split
from .hogrl_mode_dual import *
from .hogrl_utils import *
# from .hogrl_utils import *
import numpy as np
import random as rd
from sklearn.metrics import f1_score, accuracy_score, recall_score, roc_auc_score, average_precision_score, silhouette_score
from torch_geometric.utils import degree, to_undirected, dropout_adj
import numpy as np
import os
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.decomposition import PCA
import time
from sklearn.manifold import TSNE
import torch.nn as nn
import torch.nn.functional as F
import gc
def clear_gpu_memory():
    """Run the Python garbage collector, then empty PyTorch's CUDA cache."""
    # Collect first so tensors that just became unreachable are freed before
    # the CUDA caching allocator is asked to release its unused blocks.
    gc.collect()
    torch.cuda.empty_cache()
class GradientConfidenceAware(nn.Module):
    """Cross-entropy re-weighted by focal, class-balance and gradient terms.

    The per-sample weight is the product of three factors:

    * a focal factor ``(1 - p_t) ** gamma_focal``,
    * a class weight derived from a running count of which classes appear
      among the hardest ``k_percent`` samples,
    * the L2 norm of the gradient of an auxiliary loss w.r.t. the inputs.

    Args:
        num_classes: Number of classes (size of the running-statistics buffers).
        k_percent: Percentage of highest-loss samples used to update the
            running class counts.
        gamma_focal: Focal exponent; it is multiplied by 2.5 before use.
        gamma_ga: Controls the class-weight exponent ``1 - gamma_ga``.
        use_softmax: If True ``inputs`` are treated as logits and softmax is
            applied over dim 1; otherwise they are used as probabilities.
    """

    def __init__(self, num_classes, k_percent=10, gamma_focal=0.8, gamma_ga=0.5, use_softmax=True):
        super().__init__()
        self.num_classes = num_classes
        self.k_percent = k_percent
        self.gamma_focal = gamma_focal * 2.5
        self.gamma_ga = gamma_ga
        self.use_softmax = use_softmax
        self.register_buffer('class_counts', torch.zeros(num_classes))
        self.register_buffer('class_weights', torch.ones(num_classes))

    def forward(self, inputs, targets):
        """Return the weighted mean loss.

        Args:
            inputs: Tensor of shape ``(B, C)`` or ``(B, C, ...)``.
            targets: Integer class indices with ``B * prod(...)`` elements.

        Returns:
            Scalar loss tensor. The ``class_counts`` / ``class_weights``
            buffers are updated as a side effect of every call.
        """
        eps = 1e-8
        B, C = inputs.shape[:2]
        N = inputs.shape[2:].numel() * B  # total number of samples
        # Axis order that moves the class axis last: (B, C, ...) -> (B, ..., C).
        class_last = (0, *range(2, inputs.dim()), 1)

        # 1. Probabilities and base cross-entropy, flattened to (N, C).
        probs = F.softmax(inputs, dim=1) if self.use_softmax else inputs
        probs = probs.permute(*class_last).reshape(-1, C)
        targets = targets.reshape(-1)
        pt = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        ce_loss = -torch.log(pt + eps)

        # 2. Gradient probe on a detached copy. enable_grad() keeps this
        #    working when the loss is evaluated inside torch.no_grad().
        with torch.enable_grad():
            inputs_grad = inputs.detach().requires_grad_(True)
            probs_grad = F.softmax(inputs_grad, dim=1) if self.use_softmax else inputs_grad
            # The class axis must be last before flattening, otherwise rows
            # do not line up with `targets` for inputs with more than 2 dims.
            probs_grad = probs_grad.permute(*class_last).reshape(-1, C)
            # NOTE(review): cross_entropy applies log-softmax to values that
            # are already probabilities; kept unchanged to preserve training
            # behaviour - confirm whether this is intended.
            loss_grad = F.cross_entropy(probs_grad, targets, reduction='none')
            gradients = torch.autograd.grad(
                outputs=loss_grad,
                inputs=inputs_grad,
                grad_outputs=torch.ones_like(loss_grad),
                create_graph=False,
            )[0]  # same shape as inputs

        # 3. Per-sample gradient magnitude (L2 norm over the class axis).
        gradients = gradients.permute(*class_last).reshape(-1, C)
        grad_weight = gradients.norm(p=2, dim=1) + eps  # avoid zero weights

        # 4. Running class statistics from the hardest samples (no gradient).
        with torch.no_grad():
            num_topk = min(ce_loss.numel(), max(1, int(self.k_percent / 100 * N)))
            _, topk_indices = torch.topk(ce_loss, num_topk, sorted=False)
            current_counts = torch.bincount(
                targets[topk_indices], minlength=self.num_classes
            ).float()
            self.class_counts = 0.9 * self.class_counts + 0.1 * current_counts
            class_weights = (1.0 / (self.class_counts + eps)) ** (1.0 - self.gamma_ga)
            self.class_weights = class_weights / class_weights.sum() * C

        # 5. Combine focal, class and gradient weights.
        focal_weight = (1 - pt) ** self.gamma_focal
        class_weight = self.class_weights[targets]

        difficulty_weight = class_weight * grad_weight
        difficulty_weight = difficulty_weight / difficulty_weight.mean()

        final_weight = focal_weight * difficulty_weight
        # When every sample is predicted with p_t == 1 the focal weights are
        # all zero; clamp the normaliser so the loss is 0 instead of NaN.
        final_weight = final_weight / final_weight.mean().clamp_min(1e-12)

        # 6. Final loss.
        return (final_weight * ce_loss).mean()
    


def get_step(split: int, classes_num: int, pgd_nums: int, classes_freq: list):
    """Return one step count per class: ``pgd_nums`` plus a frequency bonus.

    The bonus for class ``i`` is ``round(ratio * 0.1 * pgd_nums - 1)`` floored
    at zero, where ``ratio`` is ``classes_freq[i]`` relative to
    ``classes_freq[0]`` for ``i < split`` and relative to
    ``classes_freq[-1]`` otherwise.
    """
    base_step = pgd_nums * 0.1

    def bonus_for(cls_idx):
        # Classes before `split` are scaled against the first frequency,
        # the remaining ones against the last.
        reference = classes_freq[0] if cls_idx < split else classes_freq[-1]
        extra = round((classes_freq[cls_idx] / reference) * base_step - 1)
        return 0 if extra < 0 else extra

    return [pgd_nums + bonus_for(cls_idx) for cls_idx in range(classes_num)]

class LogitsAdversarialPerturbation(nn.Module):
    """Adversarially perturb logits with per-sample step counts and step sizes.

    Running class counts and running per-class mean cross-entropy are kept as
    buffers and used to derive, for every sample, how many perturbation steps
    to take and how large each step is. The perturbation itself is computed
    without gradients and added to the original logits.

    Args:
        num_classes: Number of classes.
        pgd_nums: Base number of perturbation steps.
        alpha: Base step size.
    """

    def __init__(self, num_classes=2, pgd_nums=50, alpha=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.pgd_nums = pgd_nums
        self.alpha = alpha
        self.criterion = nn.CrossEntropyLoss()

        # Exponential moving averages, updated on every call.
        self.register_buffer('class_counts', torch.zeros(num_classes))
        self.register_buffer('class_grad_mags', torch.zeros(num_classes))
        self.momentum = 0.9

    def update_statistics(self, logit, y):
        """Update the running class counts and per-class mean cross-entropy."""
        with torch.no_grad():
            batch_counts = torch.bincount(y, minlength=self.num_classes).float()
            self.class_counts = self.momentum * self.class_counts + (1 - self.momentum) * batch_counts

            # Despite the buffer name, the tracked quantity is the mean
            # cross-entropy of each class present in the batch.
            grad_mags = torch.zeros(self.num_classes, device=logit.device)
            for c in range(self.num_classes):
                class_mask = (y == c)
                if class_mask.any():
                    grad_mags[c] = F.cross_entropy(logit[class_mask], y[class_mask])

            self.class_grad_mags = self.momentum * self.class_grad_mags + (1 - self.momentum) * grad_mags

    def compute_adaptive_params(self, logit, y):
        """Return ``(sample_steps, sample_alphas)`` for the batch.

        Rarer classes get more steps; the rarest class additionally gets a
        larger step size. Both are then scaled per sample by the relative
        norm of the cross-entropy gradient w.r.t. the logits.
        """
        with torch.no_grad():
            self.update_statistics(logit, y)

            class_ratios = self.class_counts / (torch.sum(self.class_counts) + 1e-8)

            # argmax instead of `1 - minority_idx`: identical for two classes
            # and still correct when num_classes > 2.
            minority_idx = torch.argmin(class_ratios).item()
            majority_idx = torch.argmax(class_ratios).item()

            imbalance_ratio = class_ratios[majority_idx] / (class_ratios[minority_idx] + 1e-8)
            imbalance_factor = torch.clamp(imbalance_ratio, 1.0, 10.0).item()

            grad_scale = F.softmax(self.class_grad_mags, dim=0)

            class_steps = torch.zeros(self.num_classes, device=logit.device, dtype=torch.long)
            class_alphas = torch.zeros(self.num_classes, device=logit.device, dtype=torch.float)

            max_steps = int(self.pgd_nums * 2.0)
            min_steps = max(1, int(self.pgd_nums * 0.5))

            for c in range(self.num_classes):
                # Step count grows with the inverse square root of frequency.
                freq_factor = torch.sqrt(1.0 / (class_ratios[c] + 1e-8))
                class_steps[c] = min_steps + int((max_steps - min_steps) * freq_factor / (freq_factor + 1.0))

                alpha = self.alpha * (1.0 + grad_scale[c].item() * 2.0)
                if c == minority_idx:
                    alpha = alpha * min(5.0, imbalance_factor ** 0.5)
                class_alphas[c] = alpha

            # Broadcast the per-class values to every sample.
            sample_steps = class_steps[y]
            sample_alphas = class_alphas[y]

            with torch.enable_grad():
                logit_grad = logit.detach().clone().requires_grad_(True)
                loss = F.cross_entropy(logit_grad, y, reduction='none')
                grads = torch.autograd.grad(
                    outputs=loss.sum(),
                    inputs=logit_grad,
                    create_graph=False,
                    retain_graph=False
                )[0]

                sample_grad_norms = torch.norm(grads, p=2, dim=1)
                sample_difficulties = F.softmax(sample_grad_norms, dim=0)
                relative_difficulty = sample_difficulties / (torch.max(sample_difficulties) + 1e-8)

                # Step sizes are scaled into [0.8, 1.5], step counts into [1.0, 1.5].
                sample_alphas = sample_alphas * (0.8 + 0.7 * relative_difficulty)
                sample_steps = (sample_steps.float() * (1.0 + 0.5 * relative_difficulty)).long()

            return sample_steps, sample_alphas

    def compute_adv_sign(self, logit, y, sample_alphas):
        """Return one perturbation step ``adv_logit`` and the per-class ``sub``.

        ``sub`` is the mean true-class probability over the classes present in
        the batch minus each class's own mean true-class probability; its sign
        decides the direction of the step for samples of that class.
        """
        with torch.no_grad():
            logit_softmax = F.softmax(logit, dim=-1)
            y_onehot = F.one_hot(y, num_classes=self.num_classes)

            sum_class_logit = torch.matmul(y_onehot.t().to(logit_softmax.dtype), logit_softmax)
            sum_class_num = torch.sum(y_onehot, dim=0)

            # Record which classes occur BEFORE guarding the division, so the
            # threshold below really ignores absent classes.
            present = sum_class_num > 0
            safe_class_num = sum_class_num.clamp(min=1)
            mean_class_logit = torch.div(sum_class_logit, safe_class_num.reshape(-1, 1))

            grad = mean_class_logit - torch.eye(self.num_classes, device=logit.device)
            # NOTE(review): the norm is taken over dim=0 but applied per row;
            # kept unchanged - confirm this is intended.
            grad = torch.div(grad, torch.norm(grad, p=2, dim=0).reshape(-1, 1) + 1e-8)

            mean_class_p = torch.diag(mean_class_logit)
            mean_class_thr = torch.mean(mean_class_p[present])
            sub = mean_class_thr - mean_class_p
            sign = sub.sign()

            adv_logit = torch.index_select(grad, 0, y) * sample_alphas.unsqueeze(1) * sign[y].unsqueeze(1)

            return adv_logit, sub

    def compute_eta(self, logit, y):
        """Return ``(eta, sample_steps, sample_alphas)`` with ``eta`` the total perturbation.

        All samples are advanced together because each step depends on
        batch-level class means; every sample keeps the state reached after
        its own number of steps.
        """
        with torch.no_grad():
            sample_steps, sample_alphas = self.compute_adaptive_params(logit, y)
            max_steps = int(torch.max(sample_steps).item())

            current_logit = logit.clone()
            # Samples with zero steps keep their original logits.
            logit_news = logit.clone()

            for step in range(1, max_steps + 1):
                adv_logit, _ = self.compute_adv_sign(current_logit, y, sample_alphas)
                current_logit = current_logit + adv_logit
                # Freeze the samples whose step budget ends here; this avoids
                # storing the whole trajectory and a per-sample Python loop.
                reached = sample_steps == step
                logit_news[reached] = current_logit[reached]

            eta = logit_news - logit

            return eta, sample_steps, sample_alphas

    def forward(self, models_or_logits, x=None, y=None, is_logits=False):
        """Return ``(loss_adv, logit, logit_news)``.

        Args:
            models_or_logits: Logits when ``is_logits`` is True, otherwise a
                callable applied to ``x`` to obtain them.
            x: Model input, only used when ``is_logits`` is False.
            y: Integer class targets (required).
            is_logits: Whether the first argument already holds logits.

        Raises:
            ValueError: If ``y`` is not provided.
        """
        if y is None:
            raise ValueError("y (target labels) is required")

        logit = models_or_logits if is_logits else models_or_logits(x)

        # eta carries no gradient, so gradients flow only through `logit`.
        eta, _, _ = self.compute_eta(logit, y)
        logit_news = logit + eta
        loss_adv = self.criterion(logit_news, y)

        return loss_adv, logit, logit_news

def visualize_clustering(embeddings, pseudo_labels, high_confidence_idx, epoch, save_path, overwrite_previous=True):
    """Save a 3-D t-SNE scatter plot of embeddings coloured by pseudo label.

    Args:
        embeddings: Tensor of node embeddings (moved to CPU for plotting).
        pseudo_labels: Tensor of pseudo labels, one per embedding.
        high_confidence_idx: Optional tensor of indices; when given, only
            those rows are plotted.
        epoch: Not used by the current implementation.
        save_path: File name stem; the image is written under
            ``logs/paper_fig``.
        overwrite_previous: If True the file name gets a ``_latest`` suffix.
    """
    from matplotlib.lines import Line2D
    from sklearn.metrics import silhouette_score

    normal_color, fraud_color = '#66CCCC', '#FF7F0E'

    output_dir = 'logs/paper_fig'
    os.makedirs(output_dir, exist_ok=True)

    # Select the samples to plot (all of them unless indices are given).
    point_labels = pseudo_labels.cpu().numpy()
    point_features = embeddings.cpu().numpy()
    if high_confidence_idx is not None:
        keep = high_confidence_idx.cpu().numpy()
        point_labels = point_labels[keep]
        point_features = point_features[keep]

    coords = TSNE(n_components=3, random_state=42).fit_transform(point_features)

    # Bare 3-D canvas: white background, no grid, no axes.
    plt.figure(figsize=(16, 12))
    ax = plt.axes(projection='3d')
    ax.set_facecolor('white')
    ax.grid(False)
    ax.set_axis_off()
    ax.set_box_aspect([1, 1, 0.8])

    # The silhouette score needs at least two distinct labels.
    has_two_clusters = len(np.unique(point_labels)) > 1 and len(point_labels) > 1
    silhouette_avg = silhouette_score(coords, point_labels) if has_two_clusters else 0

    ax.scatter(
        coords[:, 0], coords[:, 1], coords[:, 2],
        c=point_labels,
        cmap=ListedColormap([normal_color, fraud_color]),
        s=30,
        alpha=0.8,
        marker='o',
    )

    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor=face,
               markersize=12, label=name)
        for name, face in (('normal', normal_color), ('fraudulent', fraud_color))
    ]
    ax.legend(
        handles=legend_elements,
        loc='upper right',
        title=f'silhouette: {silhouette_avg:.3f}',
        fontsize=12,
        title_fontsize=12,
    )

    ax.view_init(elev=20, azim=45)
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    plt.tight_layout(pad=0)

    file_name = f'{save_path}_latest.png' if overwrite_previous else f'{save_path}.png'
    save_file = os.path.join(output_dir, file_name)

    plt.savefig(save_file, dpi=600, bbox_inches='tight', facecolor='white', format='png',
                transparent=False, pad_inches=0)
    plt.close()
    

def test(idx_eval, y_eval, gnn_model, feat_data, edge_indexs):
    """Evaluate ``gnn_model`` on the given nodes and return summary metrics.

    Entries whose label equals 2 are excluded from every metric.

    Args:
        idx_eval: Node indices to evaluate.
        y_eval: Labels aligned position-by-position with ``idx_eval``.
        gnn_model: Model called as ``gnn_model(feat_data, edge_indexs)``; its
            first output is exponentiated to obtain class probabilities, so it
            is presumably log-probabilities - TODO confirm.
        feat_data: Node features passed to the model.
        edge_indexs: Edge indices passed to the model.

    Returns:
        Tuple ``(auc, ap, macro_f1, g_mean, acc_label0, acc_label1,
        acc_overall, loss_variance, loss_std)``. If no entry remains after
        filtering, the fixed tuple of seven ``0.5`` values followed by two
        ``0.0`` values is returned.
    """
    gnn_model.eval()

    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        logits, _ = gnn_model(feat_data, edge_indexs)
    x_softmax = torch.exp(logits).cpu().detach()

    # Keep only entries with a usable label (label 2 is skipped).
    valid_indices = []
    valid_y_eval = []
    for i, idx in enumerate(idx_eval):
        if y_eval[i] != 2:
            valid_indices.append(idx)
            valid_y_eval.append(y_eval[i])

    if not valid_indices:
        return 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.0, 0.0

    positive_class_probs = x_softmax[:, 1].numpy()[np.array(valid_indices)]
    y_eval_np = np.array(valid_y_eval)

    auc_score = roc_auc_score(y_eval_np, positive_class_probs)

    # Hard predictions at a 0.5 threshold on the positive-class probability.
    label_prob = (positive_class_probs >= 0.5).astype(int)
    acc_overall = accuracy_score(y_eval_np, label_prob)

    # Per-class accuracy; 0.0 when the class is absent.
    num_label0 = np.sum(y_eval_np == 0)
    if num_label0 > 0:
        acc_label0 = np.sum((y_eval_np == 0) & (label_prob == 0)) / num_label0
    else:
        acc_label0 = 0.0

    num_label1 = np.sum(y_eval_np == 1)
    if num_label1 > 0:
        acc_label1 = np.sum((y_eval_np == 1) & (label_prob == 1)) / num_label1
    else:
        acc_label1 = 0.0

    ap_score = average_precision_score(y_eval_np, positive_class_probs)
    f1_score_val = f1_score(y_eval_np, label_prob, average='macro')
    g_mean = calculate_g_mean(y_eval_np, label_prob)

    # Spread of the per-sample cross-entropy over the evaluated nodes.
    with torch.no_grad():
        valid_indices_tensor = torch.tensor(valid_indices, device=logits.device)
        valid_y_eval_tensor = torch.tensor(valid_y_eval, device=logits.device, dtype=torch.long)
        valid_logits = logits[valid_indices_tensor].float()
        sample_losses = F.cross_entropy(valid_logits, valid_y_eval_tensor, reduction='none')

        loss_variance = torch.var(sample_losses).item()
        loss_std = torch.std(sample_losses).item()

    return auc_score, ap_score, f1_score_val, g_mean, acc_label0, acc_label1, acc_overall, loss_variance, loss_std

def sigmoid_rampup(current, rampup_length):
    """Ramp-up factor ``exp(-5 * (1 - t)^2)`` from https://arxiv.org/abs/1610.02242.

    ``t`` is ``current / rampup_length`` with ``current`` clipped to
    ``[0, rampup_length]``. A ``rampup_length`` of 0 disables the ramp and
    returns 1.0.
    """
    if rampup_length == 0:
        return 1.0

    progress = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - progress / rampup_length
    return float(np.exp(-5.0 * remaining * remaining))

def get_current_mu(epoch, args):
    """Return ``args['mu']``, scaled by the ramp-up factor when enabled.

    When ``args['mu_rampup']`` is truthy and ``args['consistency_rampup']``
    is None, the latter is set to 100 in ``args`` before it is used.
    """
    if not args['mu_rampup']:
        return args['mu']

    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    if args['consistency_rampup'] is None:
        args['consistency_rampup'] = 100
    return args['mu'] * sigmoid_rampup(epoch, args['consistency_rampup'])

def initialize_centroids(features, k):
    """Pick ``k`` rows of ``features`` as initial centroids, k-means++ style.

    The first centroid is a uniformly random row. Each further centroid is
    sampled with probability proportional to the row's distance to its
    nearest already-chosen centroid.

    Args:
        features: 2-D tensor ``[num_nodes, feature_dim]``.
        k: Number of centroids to pick.

    Returns:
        Tensor ``[k, feature_dim]`` on the device and with the dtype of
        ``features``.
    """
    num_nodes = features.size(0)
    # Follow the dtype of `features` so cdist does not see mixed dtypes.
    centroids = features.new_zeros(k, features.size(1))

    first_id = torch.randint(num_nodes, (1,)).item()
    centroids[0] = features[first_id]

    for i in range(1, k):
        # Distance of every row to its nearest already-chosen centroid.
        distances = torch.min(torch.cdist(features, centroids[:i]), dim=1)[0]
        total = distances.sum()
        if total == 0:
            # Every row coincides with a chosen centroid; dividing would give
            # 0/0 and multinomial would reject the NaNs, so sample uniformly.
            probabilities = torch.full_like(distances, 1.0 / num_nodes)
        else:
            probabilities = distances / total
        next_id = torch.multinomial(probabilities, 1).item()
        centroids[i] = features[next_id]

    return centroids

def check_convergence(centroids, prev_centroids, tol=1e-4):
    """Return a boolean tensor: is the norm of the centroid change below ``tol``?"""
    shift = torch.norm(centroids - prev_centroids)
    return shift < tol


def PseudolabelClusteringSelfcorrection(features, k=2, temperature=0.1, max_iterations=10, labeled_features=None, labeled_classes=None):
    """Soft-cluster node features around ``k`` centroids.

    With labelled samples, the centroids are the L2-normalised class means
    (a random unit vector for a class without samples) and no refinement is
    run. Without them, centroids come from ``initialize_centroids`` and are
    refined for at most ``max_iterations`` soft k-means steps. Assignments
    are drawn with Gumbel-softmax, so they are stochastic.

    Args:
        features: Node features ``[num_nodes, feature_dim]``.
        k: Number of clusters.
        temperature: Divides the negated distances and is also used as the
            Gumbel-softmax ``tau``.
        max_iterations: Upper bound on refinement iterations.
        labeled_features: Optional features of labelled samples.
        labeled_classes: Optional labels of those samples.

    Returns:
        ``(assignments, assignments, assignments, centroids)``: the same
        ``[num_nodes, k]`` soft-assignment tensor is returned for the
        original graph and both views, followed by the ``[k, feature_dim]``
        centroids.
    """
    feature_dim = features.size(1)
    device = features.device
    has_labels = labeled_features is not None and labeled_classes is not None

    # Centroid estimation never needs gradients.
    with torch.no_grad():
        if has_labels:
            centroids = torch.zeros(k, feature_dim, device=device)
            for cls in range(k):
                members = torch.where(labeled_classes == cls)[0]
                if len(members) > 0:
                    centroids[cls] = labeled_features[members].mean(dim=0)
                else:
                    # No labelled sample for this class: random unit vector.
                    centroids[cls] = torch.randn(feature_dim, device=device)
                    centroids[cls] = F.normalize(centroids[cls], p=2, dim=0)

            # Bring all centroids to the same norm.
            centroid_norms = torch.norm(centroids, dim=1, keepdim=True)
            centroids = centroids / (centroid_norms + 1e-10)
        else:
            centroids = initialize_centroids(features, k)
            previous = centroids.clone()

            for _ in range(max_iterations):
                node_to_centroid = torch.cdist(features, centroids)  # [num_nodes, k]
                soft_assign = F.gumbel_softmax(-node_to_centroid / temperature, tau=temperature, hard=False)

                # Each centroid becomes the assignment-weighted mean.
                refreshed = torch.zeros_like(centroids)
                for j in range(k):
                    column = soft_assign[:, j].unsqueeze(1)  # [num_nodes, 1]
                    mass = column.sum()
                    if mass > 0:
                        refreshed[j] = (features * column).sum(0) / mass
                    else:
                        refreshed[j] = centroids[j].clone()  # keep the old centre
                centroids = refreshed

                if check_convergence(centroids, previous, tol=1e-4):
                    break
                previous = centroids.clone()

    # Final soft assignments are computed outside no_grad.
    final_logits = -torch.cdist(features, centroids) / temperature
    original_cluster_assignments = F.gumbel_softmax(final_logits, tau=temperature, hard=False)

    # Both augmented views reuse the assignments of the original graph.
    return (
        original_cluster_assignments,
        original_cluster_assignments,
        original_cluster_assignments,
        centroids,
    )

def compute_pcsc_loss(features, cluster_assignments, centroids, temperature_c=0.5):
    """Clustering loss built from KL divergences between softmaxed vectors.

    The loss is the sum of three terms:
      * intra-cluster: assignment-weighted ``KL(node || centroid)``, averaged
        over nodes (weight 1.0);
      * inter-cluster: ``-logsumexp(-sym_KL / temperature)`` over all centroid
        pairs (weight ``min(0.5, 10 / num_nodes)``);
      * centre regulariser: mean smooth-L1 of the centroids against zero
        (weight ``0.01 * feature_dim / 16``).

    Args:
        features: ``[num_nodes, feature_dim]``.
        cluster_assignments: ``[num_nodes, k]``.
        centroids: ``[k, feature_dim]``.
        temperature_c: Temperature for re-sharpening the assignments and for
            the inter-cluster term.

    Returns:
        ``(loss, num_pos, num_neg)`` where ``num_neg`` counts nodes whose hard
        assignment is cluster 0 and ``num_pos`` counts all other nodes.
    """
    eps = 1e-8
    temperature = temperature_c

    # Assignment statistics and term weights carry no gradient.
    with torch.no_grad():
        tempered_assignments = torch.softmax(cluster_assignments / temperature, dim=1)
        hard_labels = torch.argmax(tempered_assignments, dim=1)
        num_pos = (hard_labels != 0).sum().item()
        num_neg = (hard_labels == 0).sum().item()

        intra_weight = 1.0
        inter_weight = min(0.5, 10.0 / (num_pos + num_neg + eps))
        center_weight = 0.01 * (features.size(1) / 16.0)

    # Turn features and centroids into distributions over the feature axis.
    node_probs = F.softmax(features, dim=-1)
    centroid_probs = F.softmax(centroids, dim=-1)
    log_node_probs = torch.log(node_probs + eps)
    log_centroid_probs = torch.log(centroid_probs + eps)

    # node_kl[i, j] = D_KL(node_probs[i] || centroid_probs[j])
    node_kl = (
        node_probs.unsqueeze(1)
        * (log_node_probs.unsqueeze(1) - log_centroid_probs.unsqueeze(0))
    ).sum(dim=-1)
    intra_cluster_loss = torch.mean(torch.sum(tempered_assignments * node_kl, dim=1))

    # centroid_kl[r, c] = D_KL(centroid_probs[r] || centroid_probs[c])
    centroid_kl = (
        centroid_probs.unsqueeze(1)
        * (log_centroid_probs.unsqueeze(1) - log_centroid_probs.unsqueeze(0))
    ).sum(dim=-1)

    # Symmetrised KL for every unordered pair r < c (row-major order).
    k_centroids = centroid_probs.size(0)
    pair_index = torch.triu_indices(k_centroids, k_centroids, offset=1, device=centroid_kl.device)
    if pair_index.size(1) > 0:
        rows, cols = pair_index[0], pair_index[1]
        centroid_pairs = (centroid_kl[rows, cols] + centroid_kl[cols, rows]) / 2.0
        inter_cluster_loss = -torch.logsumexp(-centroid_pairs / temperature, dim=0)
    else:
        inter_cluster_loss = torch.tensor(0.0, device=features.device)

    center_reg = torch.mean(
        F.smooth_l1_loss(centroids, torch.zeros_like(centroids), reduction='none')
    )

    total_loss = (
        intra_weight * intra_cluster_loss
        + inter_weight * inter_cluster_loss
        + center_weight * center_reg
    )

    return total_loss, num_pos, num_neg



def hogrl_main(args):
    """Train and evaluate the HOGRL model with one labeled node per class.

    The routine loads the dataset named by ``args['dataset']``, performs a
    stratified train/validation/test split, keeps one positive and one
    negative training node as the only labeled examples and treats every
    other training node as unlabeled.  Two augmented graph views are built
    once, and every batch is optimised with:

    * supervised NLL loss on both views (labeled nodes),
    * NT-Xent contrastive loss on same-class pairs,
    * MSE consistency between the two views' embeddings,
    * PCSC clustering loss, GCAL pseudo-label loss and LAP loss on the
      unlabeled nodes (the last four terms are scaled by ``current_mu``).

    After each epoch the model is evaluated on the validation and test
    splits.  The weights with the best validation AUC are snapshotted
    (cloned) and restored once training has finished, so the final report
    describes that checkpoint.  Progress is printed and appended to a
    timestamped file under ``logs/``.

    Args:
        args: dict-like configuration.  Keys read without a default:
            ``gpu_id``, ``dataset`` (``'yelp'``, ``'ffsd'`` or ``'amazon'``),
            ``layers_tree``, ``seed``, ``test_size``, ``val_size``,
            ``emb_size``, ``drop_rate``, ``weight``, ``layers``,
            ``batch_size``, ``num_epochs`` and ``temperature_c``.
            Optional keys (default in parentheses):
            ``enable_visualization`` (False), ``balance_pseudo_labels``
            (False), ``use_original_pseudo_labels`` (True),
            ``use_clustering_pseudo_labels`` (True), ``k_percent`` (10),
            ``gamma_focal`` (0.8), ``gamma_ga`` (0.5), ``pgd_nums`` (30),
            ``alpha`` (0.05), ``debug_print`` (False), ``mu`` (1.5),
            ``mu_rampup`` (True), ``consistency_rampup`` (None),
            ``overwrite_viz`` (False), ``clustering_temperature`` (0.8).
            Missing optional keys from ``mu`` onwards are written back into
            ``args``.

    Raises:
        AssertionError: if ``args['dataset']`` is ``'CCFD'`` (data not
            distributed).
        ValueError: if ``args['dataset']`` is not a supported dataset.
    """
    if torch.cuda.is_available() and args['gpu_id'] >= 0:
        device = torch.device(f"cuda:{args['gpu_id']}")
    else:
        device = torch.device('cpu')

    enable_visualization = args.get('enable_visualization', False)
    balance_pseudo_labels = args.get('balance_pseudo_labels', False)
    use_original_pseudo_labels = args.get('use_original_pseudo_labels', True)
    use_clustering_pseudo_labels = args.get('use_clustering_pseudo_labels', True)

    model_name = "HOGRL_GradConf"
    log_file_path = f'logs/{model_name}_{time.strftime("%Y%m%d_%H%M%S")}.log'
    # The log directory is not guaranteed to exist on a fresh checkout.
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

    def log_and_print(message):
        """Print ``message`` and append it to the experiment log file."""
        print(message)
        # Explicit UTF-8: some messages contain non-ASCII characters.
        with open(log_file_path, 'a', encoding='utf-8') as file:
            file.write(message + '\n')

    def scalar_of(value):
        """Return a Python number for a loss that may be a tensor or a number."""
        return value.item() if isinstance(value, torch.Tensor) else value

    def move_views_to_device(relation_views):
        """Move every relation's edge tensors to ``device`` in place."""
        for relation in relation_views:
            relation[0] = relation[0].to(device)
            relation[1] = [tensor.to(device) for tensor in relation[1]]

    # Number of initial epochs in which clustering is seeded with the
    # labeled nodes' features and classes.
    fixed_cluster_epochs = 10

    GCAL = GradientConfidenceAware(num_classes=2,
                                   k_percent=args.get('k_percent', 10),
                                   gamma_focal=args.get('gamma_focal', 0.8),
                                   gamma_ga=args.get('gamma_ga', 0.5),
                                   use_softmax=True).to(device)

    LAP_loss = LogitsAdversarialPerturbation(
        num_classes=2,
        pgd_nums=args.get('pgd_nums', 30),
        alpha=args.get('alpha', 0.05)
    ).to(device)

    debug_print = args.get('debug_print', False)
    print(f"Using device: {device}")

    prefix = os.path.join(os.path.dirname(__file__), "..", "..", "data/")
    print('loading data...')
    edge_indexs, feat_data, labels = load_data(args['dataset'], args['layers_tree'], prefix)

    labels_tensor = torch.tensor(labels).to(device)

    np.random.seed(args['seed'])
    random.seed(args['seed'])

    # ------------------------------------------------------------------
    # Dataset-specific choice of the labeled index range; the split logic
    # itself is shared.
    # ------------------------------------------------------------------
    dataset_name = args['dataset']
    if dataset_name == 'yelp' or dataset_name == 'CCFD' or dataset_name == 'ffsd':
        assert args['dataset'] != 'CCFD', 'Due to confidentiality agreements, we are unable to provide the CCFD data.'
        labeled_index = list(range(len(labels)))
        labeled_targets = labels
        unlabeled_pool = []
        use_pot = False
    elif dataset_name == 'amazon':
        # Nodes before index 3305 are excluded from the split and used only
        # as unlabeled data.
        labeled_index = list(range(3305, len(labels)))
        labeled_targets = labels[3305:]
        unlabeled_pool = list(range(0, 3305))
        use_pot = True
    else:
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected 'yelp', 'ffsd' or 'amazon'."
        )

    lambda_cl = 0.7
    drop_edge_rate_1 = 0.2
    drop_edge_rate_2 = 0.3

    idx_train_val, idx_test, y_train_val, y_test = train_test_split(
        labeled_index,
        labeled_targets,
        stratify=labeled_targets,
        test_size=args['test_size'],
        random_state=2,
        shuffle=True
    )
    idx_train, idx_val, y_train, y_val = train_test_split(
        idx_train_val,
        y_train_val,
        stratify=y_train_val,
        test_size=args['val_size'],
        random_state=2,
        shuffle=True
    )

    original_train = idx_train.copy()

    # Keep exactly one random positive and one random negative node labeled.
    train_pos, train_neg = pos_neg_split(idx_train, y_train)
    np.random.shuffle(train_pos)
    np.random.shuffle(train_neg)
    one_pos = [train_pos[0]]
    one_neg = [train_neg[0]]

    # Every other training node becomes unlabeled.
    convert_to_unlabel = list(set(original_train) - set(one_pos + one_neg))
    train_unlabeled = unlabeled_pool + convert_to_unlabel

    # Update the training set and its labels.
    idx_train = one_pos + one_neg
    y_train = labels[idx_train]

    train_pos, train_neg = pos_neg_split(idx_train, y_train)

    def nt_xent_loss(z_i, z_j, temperature=0.05):
        """
        NT-Xent Loss (Normalised Temperature-scaled Cross Entropy Loss)

        :param z_i: Tensor, representations of the first element of each pair.
        :param z_j: Tensor, representations of the second element of each pair.
        :param temperature: Float, temperature scaling factor for the loss function.
        :return: Scalar tensor, mean loss over the 2 * len(z_i) anchors.
        """
        # Normalize the feature vectors
        z_i = F.normalize(z_i, dim=-1)
        z_j = F.normalize(z_j, dim=-1)

        # Concatenate both halves and compute all pairwise similarities
        representations = torch.cat([z_i, z_j], dim=0)
        sim_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)

        # Rows i and i + N share a pair id, which marks them as positives
        pair_ids = torch.cat([torch.arange(z_i.size(0)).to(device) for _ in range(2)], dim=0)
        masks = pair_ids[:, None] == pair_ids[None, :]

        # Mask out self-similarity terms
        mask_diag = ~torch.eye(pair_ids.size(0), dtype=torch.bool).to(device)
        sim_matrix = sim_matrix[mask_diag].view(pair_ids.size(0), -1)
        masks = masks[mask_diag].view(pair_ids.size(0), -1)

        # Compute the InfoNCE loss
        exp_sim = torch.exp(sim_matrix / temperature)
        nominator = exp_sim[masks].view(pair_ids.size(0), -1).sum(dim=-1)
        denominator = torch.sum(exp_sim, dim=-1)
        return -torch.log(nominator / denominator).mean()

    def generate_contrastive_pairs(batch_nodes, labels, feat_data):
        """
        Generate positive and negative sample pairs based on given batch nodes.

        Args:
            batch_nodes: List of node indices in current batch
            labels: Node labels; partners are drawn from every index of
                this array
            feat_data: Node feature data (not used by the sampling)
        Returns:
            tuple: (positive_pairs, negative_pairs), each a list of
            (node, partner) index tuples
        """
        positive_pairs = []
        negative_pairs = []

        # Move CUDA tensors to CPU and convert to NumPy arrays
        if isinstance(labels, torch.Tensor) and labels.is_cuda:
            labels_cpu = labels.cpu().numpy()
        else:
            labels_cpu = labels

        # Ensure batch_nodes is on CPU
        if isinstance(batch_nodes, torch.Tensor) and batch_nodes.is_cuda:
            batch_nodes_cpu = batch_nodes.cpu().numpy()
        else:
            batch_nodes_cpu = batch_nodes

        for node in batch_nodes_cpu:
            # Positive pairs: nodes with same class label
            same_class_nodes = np.where(labels_cpu == labels_cpu[node])[0]
            if len(same_class_nodes) > 1:
                pos_pair = np.random.choice(same_class_nodes[same_class_nodes != node], 1)[0]
                positive_pairs.append((node, pos_pair))

            # Negative pairs: nodes with different class labels
            diff_class_nodes = np.where(labels_cpu != labels_cpu[node])[0]
            if len(diff_class_nodes) > 0:
                neg_pair = np.random.choice(diff_class_nodes, 1)[0]
                negative_pairs.append((node, neg_pair))

        return positive_pairs, negative_pairs

    gnn_model_1 = multi_HOGRL_Model(
        in_feat=feat_data.shape[1],
        out_feat=2,
        relation_nums=len(edge_indexs),
        hidden=args['emb_size'],
        drop_rate=args['drop_rate'],
        weight=args['weight'],
        num_layers=args['layers'],
        layers_tree=args['layers_tree'],
        temperature=0.5,
        dataset=args['dataset']
    ).to(device)

    move_views_to_device(edge_indexs)
    feat_data = torch.tensor(feat_data).float().to(device)

    optimizer_1 = torch.optim.Adam(
        list(gnn_model_1.parameters()),
        lr=0.002,
        weight_decay=3e-5
    )

    batch_size = args['batch_size']

    best_val_auc = 0.0
    best_model_state = None
    best_test_auc = 0.0

    print('generating augmented views...')
    aug_type1 = 'degree'
    aug_type2 = 'degree'

    feat_data1, edge_index_1 = get_augmented_view(
        edge_indexs,
        feat_data,
        aug_type=aug_type1,
        drop_rate=drop_edge_rate_1
    )

    feat_data2, edge_index_2 = get_augmented_view(
        edge_indexs,
        feat_data,
        aug_type=aug_type2,
        drop_rate=drop_edge_rate_2
    )

    move_views_to_device(edge_index_1)
    move_views_to_device(edge_index_2)
    print('training...')

    # Fill in optional hyper-parameters so they appear in the logged config.
    args.setdefault('mu', 1.5)
    args.setdefault('mu_rampup', True)
    args.setdefault('consistency_rampup', None)
    args.setdefault('overwrite_viz', False)
    args.setdefault('clustering_temperature', 0.8)

    log_and_print("=" * 50)
    log_and_print(f"Experiment start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    log_and_print("=" * 50)
    log_and_print("Experiment parameters:")
    for key, value in sorted(args.items()):
        log_and_print(f"  {key}: {value}")
    log_and_print("=" * 50)
    log_and_print(f"Dataset: {args['dataset']}")
    log_and_print(f"Training set size: {len(idx_train)}")
    log_and_print(f"Validation set size: {len(idx_val)}")
    log_and_print(f"Test set size: {len(idx_test)}")
    log_and_print(f"Positive sample ratio: {sum(y_train)/len(y_train):.4f}")
    log_and_print(f"Augmentation types: {aug_type1}, {aug_type2}")
    log_and_print(f"Edge drop rates: {drop_edge_rate_1}, {drop_edge_rate_2}")
    log_and_print(f"Pseudo label balance: {'Enabled' if balance_pseudo_labels else 'Disabled'}")
    log_and_print("=" * 50)

    # Per-epoch loss history.
    epoch_list = []
    total_loss_list = []
    sup_loss_list = []
    cls_loss_list = []
    pcsc_loss_list = []
    gcal_loss_list = []
    lap_loss_list = []
    consistency_loss_list = []

    sampled_idx_train = idx_train.copy()
    rd.shuffle(train_unlabeled)

    pos_samples = [idx for idx in sampled_idx_train if labels[idx] == 1]
    neg_samples = [idx for idx in sampled_idx_train if labels[idx] == 0]

    # Loop invariants: membership lookup and the number of batches per epoch.
    labeled_lookup = set(sampled_idx_train)
    num_batches = max(1, (len(train_unlabeled) + batch_size - 1) // batch_size)

    for epoch in range(args['num_epochs']):
        gnn_model_1.train()
        current_mu = get_current_mu(epoch, args)

        if debug_print:
            print(f"\nEpoch {epoch} start:")
            print(f"Total unlabeled data: {len(train_unlabeled)}")
            print(f"Total labeled data: {len(sampled_idx_train)}")
            print(f"Number of positive samples: {len(pos_samples)}")  # Should be 1
            print(f"Number of negative samples: {len(neg_samples)}")  # Should be 1
            print(f"Total number of batches: {num_batches}")

        epoch_sup_loss = 0
        epoch_cls_loss = 0
        epoch_consistency_loss = 0
        epoch_pcsc_loss = 0
        epoch_gcal_loss = 0
        epoch_lap_loss = 0
        epoch_total_loss = 0

        epoch_pseudo_pos_count = 0
        epoch_pseudo_neg_count = 0

        for batch in range(num_batches):
            # Every batch contains all labeled nodes plus a slice of the
            # unlabeled pool (wrapping around to the start if the slice is short).
            batch_nodes = list(sampled_idx_train)
            remaining_spots = batch_size - len(sampled_idx_train)
            if remaining_spots > 0:
                u_start = batch * remaining_spots
                u_end = min((batch + 1) * remaining_spots, len(train_unlabeled))
                if u_start < len(train_unlabeled):
                    unlabeled_slice = train_unlabeled[u_start:u_end]
                    if len(unlabeled_slice) < remaining_spots:
                        needed = remaining_spots - len(unlabeled_slice)
                        unlabeled_slice.extend(train_unlabeled[:needed])
                    batch_nodes.extend(unlabeled_slice)

            batch_labeled = [node for node in batch_nodes if node in labeled_lookup]
            batch_unlabeled = [node for node in batch_nodes if node not in labeled_lookup]

            # Explicit long dtype keeps an empty list usable as an index tensor.
            unlabeled_nodes_tensor = torch.tensor(batch_unlabeled, dtype=torch.long, device=device)
            batch_nodes_tensor = torch.tensor(batch_nodes, dtype=torch.long, device=device)
            labeled_nodes_tensor = torch.tensor(batch_labeled, dtype=torch.long, device=device)
            batch_label = torch.tensor(labels[np.array(batch_labeled)]).long().to(device)

            # Forward passes on the original graph and the two augmented views.
            original_out, original_h = gnn_model_1(feat_data, edge_indexs)
            out1, h1 = gnn_model_1(feat_data1, edge_index_1)
            out2, h2 = gnn_model_1(feat_data2, edge_index_2)

            sup_loss_1 = F.nll_loss(out1[labeled_nodes_tensor], batch_label)
            sup_loss_2 = F.nll_loss(out2[labeled_nodes_tensor], batch_label)

            # NOTE(review): partners are sampled using the full ground-truth
            # `labels` array, i.e. they are not restricted to the labeled
            # training nodes -- confirm this is intended.
            positive_pairs, negative_pairs = generate_contrastive_pairs(batch_labeled, labels, feat_data[np.array(batch_labeled)])

            anchor_idx = torch.tensor([p[0] for p in positive_pairs], device=device)
            partner_idx = torch.tensor([p[1] for p in positive_pairs], device=device)
            cls_loss_1 = nt_xent_loss(h1[anchor_idx], h1[partner_idx])
            cls_loss_2 = nt_xent_loss(h2[anchor_idx], h2[partner_idx])

            consistency_loss = F.mse_loss(h1[batch_nodes_tensor], h2[batch_nodes_tensor])

            # Auxiliary losses default to zero so the objective is always defined.
            pcsc_loss = torch.tensor(0.0, device=device)
            gcal_loss = torch.tensor(0.0, device=device)
            lap_loss = torch.tensor(0.0, device=device)

            if len(batch_unlabeled) > 0:
                h_orig_unlabeled = original_h[unlabeled_nodes_tensor]
                h1_unlabeled = h1[unlabeled_nodes_tensor]
                h2_unlabeled = h2[unlabeled_nodes_tensor]

                cluster_assignments_orig = None
                if use_clustering_pseudo_labels:
                    cluster_kwargs = {}
                    if epoch < fixed_cluster_epochs:
                        # Early epochs: seed the clustering with the labeled nodes.
                        cluster_kwargs['labeled_features'] = original_h[labeled_nodes_tensor]
                        cluster_kwargs['labeled_classes'] = torch.tensor(
                            labels[np.array(batch_labeled)], device=device
                        )
                    elif epoch == fixed_cluster_epochs and batch == 0:
                        log_and_print(f"\n【Epoch {epoch}】clustering mode")

                    (cluster_assignments_orig,
                     cluster_assignments_view1,
                     cluster_assignments_view2,
                     centroids_orig) = PseudolabelClusteringSelfcorrection(
                        h_orig_unlabeled,
                        k=2,
                        temperature=args["clustering_temperature"],
                        max_iterations=10,
                        **cluster_kwargs
                    )

                    all_features = torch.cat([h_orig_unlabeled, h1_unlabeled, h2_unlabeled], dim=0)
                    all_assignments = torch.cat(
                        [cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2],
                        dim=0
                    )
                    pcsc_loss, _, _ = compute_pcsc_loss(
                        all_features,
                        all_assignments,
                        centroids_orig,
                        temperature_c=args["temperature_c"]
                    )

                with torch.no_grad():
                    pseudo_labels = torch.tensor([], dtype=torch.long, device=device)

                    if use_clustering_pseudo_labels:
                        cluster_hard_labels = torch.argmax(cluster_assignments_orig, dim=1)
                        count_c0 = torch.sum(cluster_hard_labels == 0).item()
                        count_c1 = torch.sum(cluster_hard_labels == 1).item()

                        if use_original_pseudo_labels:
                            # Average the classifier's probabilities with the
                            # cluster assignments; the larger cluster is
                            # mapped to column 0 first.
                            orig_probs_unlabeled = F.softmax(original_out[unlabeled_nodes_tensor], dim=1)
                            aligned_cluster_probs = cluster_assignments_orig.clone()
                            if count_c0 < count_c1:
                                aligned_cluster_probs[:, 0] = cluster_assignments_orig[:, 1]
                                aligned_cluster_probs[:, 1] = cluster_assignments_orig[:, 0]
                            combined_probs_unlabeled = (aligned_cluster_probs + orig_probs_unlabeled) / 2
                            pseudo_labels = torch.argmax(combined_probs_unlabeled, dim=1)
                        elif count_c0 >= count_c1:
                            pseudo_labels = cluster_hard_labels
                        else:
                            # Flip so that the larger cluster gets label 0.
                            pseudo_labels = 1 - cluster_hard_labels
                    elif use_original_pseudo_labels:
                        pseudo_labels = torch.argmax(original_out[unlabeled_nodes_tensor], dim=1)

                    # Every pseudo label is currently kept (no confidence filter).
                    consistent_pseudo_labels = pseudo_labels
                    consistent_high_conf_indices = torch.arange(pseudo_labels.size(0), device=device)
                    epoch_pseudo_pos_count += torch.sum(consistent_pseudo_labels == 1).item()
                    epoch_pseudo_neg_count += torch.sum(consistent_pseudo_labels == 0).item()
                    num_consistent_high_conf = consistent_pseudo_labels.numel()

                if num_consistent_high_conf > 0:
                    # Best effort: a failure in the pseudo-label losses is
                    # logged and the two terms are dropped for this batch.
                    try:
                        pseudo_logits_1 = out1[unlabeled_nodes_tensor][consistent_high_conf_indices]
                        pseudo_logits_2 = out2[unlabeled_nodes_tensor][consistent_high_conf_indices]

                        if pseudo_logits_1.size(0) == consistent_pseudo_labels.size(0) and \
                           pseudo_logits_2.size(0) == consistent_pseudo_labels.size(0):
                            total_samples = consistent_pseudo_labels.size(0)

                            if total_samples > 0:
                                # `epoch < -1` never holds for non-negative
                                # epochs, so the plain cross-entropy branch is
                                # effectively disabled.
                                if epoch < -1:
                                    pseudo_label_loss_1 = F.cross_entropy(pseudo_logits_1, consistent_pseudo_labels)
                                    pseudo_label_loss_2 = F.cross_entropy(pseudo_logits_2, consistent_pseudo_labels)
                                else:
                                    # GradientConfidenceAware
                                    pseudo_label_loss_1 = GCAL(
                                        pseudo_logits_1,
                                        consistent_pseudo_labels
                                    )
                                    pseudo_label_loss_2 = GCAL(
                                        pseudo_logits_2,
                                        consistent_pseudo_labels
                                    )

                                gcal_loss = (pseudo_label_loss_1 + pseudo_label_loss_2) / 2

                                class_counts = [torch.sum(consistent_pseudo_labels == 0).item(),
                                                torch.sum(consistent_pseudo_labels == 1).item()]

                                # LAP is only applied when both classes are present.
                                if min(class_counts) > 0:
                                    adap_lpl_loss_1, _, _ = LAP_loss(pseudo_logits_1, None, consistent_pseudo_labels, is_logits=True)
                                    adap_lpl_loss_2, _, _ = LAP_loss(pseudo_logits_2, None, consistent_pseudo_labels, is_logits=True)
                                    lap_loss = (adap_lpl_loss_1 + adap_lpl_loss_2) / 2
                            else:
                                log_and_print(f"error!")
                    except Exception as e:
                        log_and_print(f"error: {e}")
                        gcal_loss = torch.tensor(0.0, device=device)
                        lap_loss = torch.tensor(0.0, device=device)

            sup_loss = (sup_loss_1 + sup_loss_2) / 2
            cls_loss = (cls_loss_1 + cls_loss_2) / 2

            total_loss = sup_loss + \
                        cls_loss + \
                        current_mu * consistency_loss + \
                        current_mu * gcal_loss + \
                        current_mu * pcsc_loss + \
                        current_mu * lap_loss

            epoch_sup_loss += sup_loss.item()
            epoch_cls_loss += cls_loss.item()
            epoch_consistency_loss += consistency_loss.item()
            epoch_pcsc_loss += scalar_of(pcsc_loss)
            epoch_gcal_loss += scalar_of(gcal_loss)
            epoch_lap_loss += scalar_of(lap_loss)
            epoch_total_loss += total_loss.item()

            optimizer_1.zero_grad()
            total_loss.backward()
            optimizer_1.step()

        # Average over the batches actually processed in this epoch.
        epoch_sup_loss /= num_batches
        epoch_cls_loss /= num_batches
        epoch_consistency_loss /= num_batches
        epoch_pcsc_loss /= num_batches
        epoch_gcal_loss /= num_batches
        epoch_lap_loss /= num_batches
        epoch_total_loss /= num_batches

        log_and_print(f'\nEpoch {epoch} Loss Summary:')
        log_and_print(f'  Supervised Loss: {epoch_sup_loss:.4f}')
        log_and_print(f'  Contrastive Loss: {epoch_cls_loss:.4f}')
        log_and_print(f'  Consistency Loss: {epoch_consistency_loss:.4f}')
        log_and_print(f'  PCSC Loss: {epoch_pcsc_loss:.4f}')
        log_and_print(f'  GCAL Loss: {epoch_gcal_loss:.4f}')
        log_and_print(f'  LAP Loss: {epoch_lap_loss:.4f}')
        log_and_print(f'  Total Loss: {epoch_total_loss:.4f}')

        epoch_list.append(epoch)
        sup_loss_list.append(epoch_sup_loss)
        cls_loss_list.append(epoch_cls_loss)
        pcsc_loss_list.append(epoch_pcsc_loss)
        gcal_loss_list.append(epoch_gcal_loss)
        lap_loss_list.append(epoch_lap_loss)
        consistency_loss_list.append(epoch_consistency_loss)
        total_loss_list.append(epoch_total_loss)

        val_auc, val_ap, val_f1, val_g_mean, val_acc_label0, val_acc_label1, val_acc_overall, val_loss_var, val_loss_std = test(idx_val, y_val, gnn_model_1, feat_data, edge_indexs)
        log_and_print(f'Epoch: {epoch}, Validation AUC: {val_auc:.4f}, AP: {val_ap:.4f}, F1: {val_f1:.4f}, G-mean: {val_g_mean:.4f}, Label 0 ACC: {val_acc_label0:.4f}, Label 1 ACC: {val_acc_label1:.4f}, Overall ACC: {val_acc_overall:.4f}')

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            # state_dict() returns references to the live parameters, so the
            # tensors must be cloned to obtain a real snapshot.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in gnn_model_1.state_dict().items()
            }

        if epoch == args['num_epochs'] - 1:
            log_and_print(f'Final validation AUC: {best_val_auc:.4f}')

        # Per-epoch test metrics are computed on the live model; the best
        # checkpoint is only restored after training so that optimisation is
        # not rewound.
        test_auc, test_ap, test_f1, test_g_mean, test_acc_label0, test_acc_label1, test_acc_overall, test_loss_var, test_loss_std = test(idx_test, y_test, gnn_model_1, feat_data, edge_indexs)

        # Update before logging so the reported best includes this epoch.
        if test_auc > best_test_auc:
            best_test_auc = test_auc

        log_and_print(f'Test results: AUC={test_auc:.4f}, AP={test_ap:.4f}, F1={test_f1:.4f}, G-mean={test_g_mean:.4f}, Best test AUC: {best_test_auc:.4f}, Label 0 ACC: {test_acc_label0:.4f}, Label 1 ACC: {test_acc_label1:.4f}, Overall ACC: {test_acc_overall:.4f}')

        if enable_visualization:
            with torch.no_grad():
                test_tensor = torch.tensor(idx_test, device=device)
                _, h_test = gnn_model_1(feat_data, edge_indexs)
                h_test = h_test[test_tensor].clone().detach()

                test_labels = torch.tensor(y_test, device=device)

                args['overwrite_viz'] = True

                auc_path = f'test_viz_{args["dataset"]}_epoch{epoch}_auc{test_auc:.4f}'

                visualize_clustering(
                    embeddings=h_test,
                    pseudo_labels=test_labels,
                    high_confidence_idx=None,
                    epoch=epoch,
                    save_path=auc_path,
                    overwrite_previous=args['overwrite_viz']
                )

    # Restore the best-validation checkpoint (if any) and evaluate it once
    # more so the final report below describes that model.
    if best_model_state is not None:
        gnn_model_1.load_state_dict(best_model_state)
    test_auc, test_ap, test_f1, test_g_mean, test_acc_label0, test_acc_label1, test_acc_overall, test_loss_var, test_loss_std = test(idx_test, y_test, gnn_model_1, feat_data, edge_indexs)

    log_and_print("\n" + "=" * 50)
    log_and_print(f"Experiment end time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    log_and_print(f"Best validation AUC: {best_val_auc:.4f}")
    log_and_print(f"Best test AUC: {best_test_auc:.4f}")
    log_and_print(f"Best test ACC: {test_acc_overall:.4f}")
    log_and_print(f"Best test negative (label 0) ACC: {test_acc_label0:.4f}")
    log_and_print(f"Best test positive (label 1) ACC: {test_acc_label1:.4f}")
    log_and_print(f"Final test results: AUC={test_auc:.4f}, AP={test_ap:.4f}, F1={test_f1:.4f}, G-mean={test_g_mean:.4f}, Label 0 ACC: {test_acc_label0:.4f}, Label 1 ACC: {test_acc_label1:.4f}, Overall ACC: {test_acc_overall:.4f}")

    log_and_print("=" * 50)
    log_and_print(f"Log file saved at: {log_file_path}")
    log_and_print("=" * 50)

    out, embedding = gnn_model_1(feat_data, edge_indexs)
    if enable_visualization:
        print('Generating embedding visualization...')
        Visualization(labels, embedding.cpu().detach(), prefix)


